use indexmap::IndexSet;
use toasty_core::stmt::{self, visit_mut, Condition};

use crate::engine::{
    eval, hir,
    index::{self, IndexPlan},
    mir,
    plan::HirPlanner,
};

/// Columns the exec statement must return, together with the statement's
/// returning clause rewritten to read from those columns.
struct Selection {
    /// Column references to select. A reference's position in this set is the
    /// field index used by the `arg_project(0, [index])` expressions that read
    /// it back out of the exec statement's output record.
    columns: IndexSet<stmt::ExprReference>,
    /// The returning clause taken off the original statement, if any.
    returning: Option<stmt::Returning>,
}

/// A statement paired with the MIR nodes it reads as inputs.
struct LinkedStatement {
    /// The statement to execute (it may be rewritten during planning).
    stmt: stmt::Statement,
    /// Input nodes consumed by the statement. Sub-statement args record their
    /// position in this set through their `input` cell.
    inputs: IndexSet<mir::NodeId>,
}

/// Planning state for lowering a single HIR statement into MIR nodes.
struct PlanStatement<'a, 'b> {
    /// The planner holding the HIR being planned and the MIR graph being built.
    planner: &'a mut HirPlanner<'b>,
    /// Identifier of the statement being planned.
    stmt_id: hir::StmtId,
    /// HIR information for the statement being planned.
    stmt_info: &'b hir::StatementInfo,

    /// True once the statement's dependencies have been handed out (via
    /// `take_dependencies`) to be attached to a MIR node.
    did_take_deps: bool,
}

impl HirPlanner<'_> {
    /// Plans the HIR statement `stmt_id`, lowering it into MIR nodes.
    ///
    /// Planning is memoized: once the statement's exec node has been recorded,
    /// later calls return immediately. Dependency statements (statements that
    /// must run before this one but do not reference it) are planned first.
    pub(super) fn plan_statement(&mut self, stmt_id: hir::StmtId) {
        let info = &self.hir[stmt_id];

        // Already planned: nothing left to do.
        if info.exec_statement.get().is_some() {
            return;
        }

        // Dependencies must run before this statement and do not reference it,
        // so they can be planned up-front.
        for &dep_id in &info.deps {
            self.plan_statement(dep_id);
        }

        // The per-statement planning logic lives on `PlanStatement`.
        PlanStatement {
            planner: self,
            stmt_id,
            stmt_info: info,
            did_take_deps: false,
        }
        .plan();
    }
}

impl<'a, 'b> PlanStatement<'a, 'b> {
    // ===== Entry point =====

    /// Plans the statement end to end.
    ///
    /// The steps are:
    /// 1. Lower the statement into an exec MIR node.
    /// 2. Record the column selection.
    /// 3. Create the back-ref projections.
    /// 4. Plan the `Sub` child statements.
    /// 5. Record the statement's output node.
    ///
    /// The order of these steps matters. Column indices assigned while
    /// extracting the returning clause are baked into rewritten expressions.
    /// The back-ref projection nodes are also created before child statements
    /// are planned. NOTE(review): children presumably read those node ids in
    /// `process_ref_args`; confirm before reordering.
    fn plan(&mut self) {
        let mut stmt = self.stmt_info.stmt.as_deref().unwrap().clone();

        // Remember whether the original query expects a single row. The flag is
        // cleared on the copy that is executed and passed along to
        // `plan_output_node` instead.
        let single = stmt.as_query().map(|query| query.single).unwrap_or(false);
        if let Some(query) = stmt.as_query_mut() {
            query.single = false;
        }

        let returning = stmt.take_returning();

        // Columns to select
        let columns = IndexSet::new();

        let mut linked_stmt = LinkedStatement {
            stmt,
            inputs: IndexSet::new(),
        };

        let mut selection = Selection { columns, returning };

        // Visit the main statement's returning clause to extract needed columns
        self.extract_columns_from_returning(&mut selection, &mut linked_stmt);

        // Track sub-statement arguments from filter
        self.extract_sub_statement_args_from_filter(&mut linked_stmt);

        // For each back ref, include the needed columns
        self.collect_back_ref_columns(&mut selection);

        // If there are any ref args, then the statement needs to be rewritten
        // to batch load all records for a NestedMerge operation.
        let ref_source = self.process_ref_args(&mut linked_stmt);

        let exec_stmt_node_id = self.plan_execution(linked_stmt, &mut selection, ref_source);

        // Track the exec statement operation node.
        self.stmt_info.exec_statement.set(Some(exec_stmt_node_id));

        // Now, for each back ref, we need to project the expression to what the
        // next statement expects.
        self.process_back_ref_projections(exec_stmt_node_id, &selection);

        // Track the selection for later use.
        self.stmt_info
            .exec_statement_selection
            .set(selection.columns)
            .unwrap();

        // Plan each child
        self.plan_child_statements();

        // Plans a NestedMerge if one is needed
        let output_node_id =
            self.plan_output_node(exec_stmt_node_id, selection.returning, single, ref_source);

        self.stmt_info.output.set(Some(output_node_id));
    }

    // ===== Setup helpers =====

    /// Rewrites the returning clause in `selection` so it reads from the exec
    /// statement's output record.
    ///
    /// Each column reference is added to `selection.columns`. It is then
    /// replaced by `arg_project(0, [index])`, where `index` is that column's
    /// position in the selected record. Returning sub-statement args are
    /// linked as inputs of `linked_stmt`, unless the statement has back-refs.
    fn extract_columns_from_returning(
        &mut self,
        selection: &mut Selection,
        linked_stmt: &mut LinkedStatement,
    ) {
        visit_mut::for_each_expr_mut(&mut selection.returning, |expr| {
            match expr {
                stmt::Expr::Reference(expr_reference) => {
                    let (index, _) = selection.columns.insert_full(*expr_reference);
                    *expr = stmt::Expr::arg_project(0, [index]);
                }
                stmt::Expr::Arg(expr_arg) => match &self.stmt_info.args[expr_arg.position] {
                    hir::Arg::Ref { .. } => {
                        todo!("refs in returning is not yet supported");
                    }
                    hir::Arg::Sub {
                        stmt_id,
                        input,
                        returning: true,
                    } => {
                        // If there are back-refs, the exec statement is preloading
                        // data for a NestedMerge. Sub-statements will be loaded
                        // during the NestedMerge.
                        if !self.stmt_info.back_refs.is_empty() {
                            return;
                        }

                        // The sub-statement's exec node must already exist at
                        // this point; its absence is a planner bug.
                        let node_id = self.planner.hir[stmt_id].exec_statement.get().expect("bug");

                        let (index, _) = linked_stmt.inputs.insert_full(node_id);
                        input.set(Some(index));
                    }
                    _ => todo!(),
                },
                _ => {}
            }
        });
    }

    /// Links sub-statement args that appear in the statement's filter.
    ///
    /// For every `Arg` in the filter that resolves to a non-returning
    /// `hir::Arg::Sub`, two things happen:
    /// - the sub-statement's output node is registered as an input of
    ///   `linked_stmt`;
    /// - the arg's `input` cell records that node's position.
    ///
    /// The debug assertion documents that this only happens without SQL
    /// capability.
    fn extract_sub_statement_args_from_filter(&mut self, linked_stmt: &mut LinkedStatement) {
        let args = &self.stmt_info.args;

        visit_mut::for_each_expr_mut(&mut linked_stmt.stmt.filter_mut(), |expr| {
            let stmt::Expr::Arg(expr_arg) = expr else {
                return;
            };

            let hir::Arg::Sub {
                stmt_id: sub_stmt_id,
                returning: false,
                input,
            } = &args[expr_arg.position]
            else {
                return;
            };

            debug_assert!(!self.planner.engine.capability().sql);
            debug_assert!(input.get().is_none());

            // The sub-statement must already have an output node.
            let sub_output = self.planner.hir[sub_stmt_id].output.get().expect("bug");

            let (position, _) = linked_stmt.inputs.insert_full(sub_output);
            input.set(Some(position));
        });
    }

    /// Adds every expression referenced by this statement's back-refs to the
    /// selection, so the exec statement returns what dependent statements need.
    ///
    /// Expressions that are already selected keep their existing position.
    fn collect_back_ref_columns(&mut self, selection: &mut Selection) {
        let back_ref_exprs = self
            .stmt_info
            .back_refs
            .values()
            .flat_map(|back_ref| back_ref.exprs.iter().copied());

        selection.columns.extend(back_ref_exprs);
    }

    /// Links the input behind this statement's `Ref` arg and rewrites the
    /// statement to batch-load records for every row of that input.
    ///
    /// The input is the referenced statement's back-ref projection node.
    ///
    /// Returns the arg pointing at the batch input, or `None` when the
    /// statement has no `Ref` args.
    ///
    /// # Panics
    ///
    /// Panics if there is more than one `Ref` arg, or if the filter is a
    /// constant `false`.
    fn process_ref_args(&mut self, linked_stmt: &mut LinkedStatement) -> Option<stmt::ExprArg> {
        let mut ref_source = None;

        for arg in &self.stmt_info.args {
            let hir::Arg::Ref {
                stmt_id: target_id,
                input,
                ..
            } = arg
            else {
                continue;
            };

            assert!(ref_source.is_none(), "TODO: handle more complex ref cases");
            assert!(
                !linked_stmt.stmt.filter_or_default().is_false(),
                "TODO: handle const false filters"
            );

            // Find the back-ref for this arg
            let node_id = self.planner.hir[target_id].back_refs[&self.stmt_id]
                .node_id
                .get()
                .unwrap();

            let (index, _) = linked_stmt.inputs.insert_full(node_id);
            ref_source = Some(stmt::ExprArg::new(index));
            // The arg's `input` is the table position used by the rewrite; it
            // is always 0, unlike `index` above, which is its position in
            // `linked_stmt.inputs`.
            input.set(Some(0));
        }

        if let Some(ref_source) = ref_source {
            self.rewrite_stmt_for_ref_source(&mut linked_stmt.stmt, ref_source);
        }

        ref_source
    }

    /// Rewrites `stmt` so it batch-loads records for the rows produced by
    /// `ref_source`.
    ///
    /// The strategy depends on whether the engine has SQL capability.
    fn rewrite_stmt_for_ref_source(
        &mut self,
        stmt: &mut stmt::Statement,
        ref_source: stmt::ExprArg,
    ) {
        if !self.planner.engine.capability().sql {
            self.rewrite_stmt_for_ref_source_nosql(stmt, ref_source);
            return;
        }

        self.rewrite_stmt_for_ref_source_sql(stmt, ref_source);
    }

    /// SQL batch-load strategy. It moves the current filter into an `EXISTS`
    /// sub-query that selects the record `(1)` from `ref_source`.
    ///
    /// Inside the moved filter, two rewrites happen:
    /// - column references get their nesting bumped by one, because they now
    ///   sit one query level deeper;
    /// - `Ref` args become columns of the `ref_source` table.
    fn rewrite_stmt_for_ref_source_sql(
        &mut self,
        stmt: &mut stmt::Statement,
        ref_source: stmt::ExprArg,
    ) {
        // If targeting SQL, leverage the SQL query engine to handle most of the rewrite details.
        let mut filter = stmt
            .filter_mut()
            .map(|filter| filter.take())
            .unwrap_or_default();

        visit_mut::for_each_expr_mut(&mut filter, |expr| {
            match expr {
                stmt::Expr::Reference(stmt::ExprReference::Column(expr_column)) => {
                    debug_assert_eq!(0, expr_column.nesting);
                    // We need to up the nesting to reflect that the filter is moved
                    // one level deeper.
                    expr_column.nesting += 1;
                }
                stmt::Expr::Arg(expr_arg) => {
                    let hir::Arg::Ref {
                        input,
                        batch_load_index: index,
                        ..
                    } = &self.stmt_info.args[expr_arg.position]
                    else {
                        todo!()
                    };

                    // Rewrite the arg as a column of the new `FROM` source
                    // (`input` holds the table position set in `process_ref_args`).
                    *expr = stmt::Expr::column(stmt::ExprColumn {
                        nesting: 0,
                        table: input.get().unwrap(),
                        column: *index,
                    });
                }
                _ => {}
            }
        });

        let sub_query = stmt::Select {
            returning: stmt::Returning::Expr(stmt::Expr::record([1])),
            source: stmt::Source::from(ref_source),
            filter,
        };

        stmt.filter_mut_unwrap().set(stmt::Expr::exists(sub_query));
    }

    /// NoSQL batch-load strategy. It works in two steps:
    /// 1. `Ref` args in the filter are rewritten to positional args (their
    ///    `batch_load_index`).
    /// 2. The filter is wrapped as `any(map(ref_source, filter))`.
    ///
    /// NOTE(review): the wrapper presumably matches a record when the filter
    /// holds for at least one batch row; confirm against the `any`/`map`
    /// evaluation semantics.
    fn rewrite_stmt_for_ref_source_nosql(
        &mut self,
        stmt: &mut stmt::Statement,
        ref_source: stmt::ExprArg,
    ) {
        let mut filter = stmt.filter_expr_mut();
        visit_mut::for_each_expr_mut(&mut filter, |expr| match expr {
            stmt::Expr::Reference(stmt::ExprReference::Column(expr_column)) => {
                // Columns stay at the same nesting level; only check it.
                debug_assert_eq!(0, expr_column.nesting);
            }
            stmt::Expr::Arg(expr_arg) => {
                let hir::Arg::Ref {
                    batch_load_index: index,
                    ..
                } = &self.stmt_info.args[expr_arg.position]
                else {
                    todo!()
                };

                *expr = stmt::Expr::arg(*index);
            }
            _ => {}
        });

        if let Some(filter) = filter {
            let expr = filter.take();
            *filter = stmt::Expr::any(stmt::Expr::map(ref_source, expr));
        }
    }

    // ===== Execution dispatch =====

    /// Lowers the linked statement into its exec MIR node.
    ///
    /// There are three paths:
    /// - constant or no-op statements become `Const` nodes;
    /// - SQL-capable engines, and inserts on any engine, go through
    ///   `plan_sql_execution`;
    /// - everything else is planned against the available indices by
    ///   `plan_nosql_execution`.
    fn plan_execution(
        &mut self,
        linked: LinkedStatement,
        selection: &mut Selection,
        ref_source: Option<stmt::ExprArg>,
    ) -> mir::NodeId {
        if let Some(const_node_id) = self.plan_const_or_empty_statement(&linked, selection) {
            return const_node_id;
        }

        let use_sql_path = self.planner.engine.capability().sql || linked.stmt.is_insert();
        if use_sql_path {
            self.plan_sql_execution(linked, selection)
        } else {
            self.plan_nosql_execution(linked, selection, ref_source)
        }
    }

    /// Short-circuits statements that need no database round trip.
    ///
    /// Two cases are handled:
    /// - A statement that is entirely constant is evaluated now, and its rows
    ///   become a `Const` node.
    /// - A statement whose assignments are empty (presumably an update that
    ///   sets nothing) becomes a `Const` node. It holds one empty sparse record
    ///   when there is a returning clause, and no rows otherwise.
    ///
    /// Returns `None` when the statement must actually be executed.
    fn plan_const_or_empty_statement(
        &mut self,
        linked: &LinkedStatement,
        selection: &Selection,
    ) -> Option<mir::NodeId> {
        if linked.stmt.is_const() {
            let stmt::Value::List(rows) = linked.stmt.eval_const().unwrap() else {
                todo!()
            };
            let rows_ty = self
                .planner
                .engine
                .infer_record_list_ty(&linked.stmt, &selection.columns);

            return Some(self.insert_const(rows, rows_ty));
        }

        let has_no_assignments = matches!(
            linked.stmt.assignments(),
            Some(assignments) if assignments.is_empty()
        );
        if !has_no_assignments {
            return None;
        }

        let rows = if selection.returning.is_some() {
            vec![stmt::Value::empty_sparse_record()]
        } else {
            Vec::<stmt::Value>::new()
        };

        Some(self.insert_const(rows, stmt::Type::list(stmt::Type::empty_sparse_record())))
    }

    // ===== SQL execution =====

    /// Plans the statement for a SQL-capable database. It is also used for
    /// inserts on any database.
    ///
    /// The selected columns become the statement's returning record.
    ///
    /// A statement with an update condition must be an `Update` with no
    /// returning clause. It is lowered to a CTE-based `ExecStatement`, or to a
    /// `ReadModifyWrite` when the engine lacks `cte_with_update`. Every other
    /// statement is passed through as a single `ExecStatement`.
    fn plan_sql_execution(
        &mut self,
        linked: LinkedStatement,
        selection: &mut Selection,
    ) -> mir::NodeId {
        let LinkedStatement { mut stmt, inputs } = linked;

        if !selection.columns.is_empty() {
            stmt.set_returning(
                stmt::Expr::record(
                    selection
                        .columns
                        .iter()
                        .map(|expr_reference| stmt::Expr::from(*expr_reference)),
                )
                .into(),
            );
        }

        let input_args: Vec<_> = inputs
            .iter()
            .map(|input| self.planner.mir.ty(*input).clone())
            .collect();

        let ty = self.planner.engine.infer_ty(&stmt, &input_args[..]);

        let node = if stmt.condition().is_some() {
            if let stmt::Statement::Update(stmt) = stmt {
                assert!(stmt.returning.is_none(), "TODO: stmt={stmt:#?}");
                assert!(
                    selection.returning.is_none(),
                    "TODO: returning={:#?}",
                    selection.returning
                );

                if self.planner.engine.capability().cte_with_update {
                    mir::Operation::ExecStatement(Box::new(
                        self.plan_conditional_sql_query_as_cte(inputs, stmt, ty),
                    ))
                } else {
                    mir::Operation::ReadModifyWrite(Box::new(
                        self.plan_conditional_sql_query_as_rmw(inputs, stmt, ty),
                    ))
                }
            } else {
                todo!("stmt={stmt:#?}");
            }
        } else {
            debug_assert!(
                stmt.returning()
                    .and_then(|returning| returning.as_expr())
                    .map(|expr| expr.is_record())
                    .unwrap_or(true),
                "stmt={stmt:#?}"
            );
            // With SQL capability, we can just punt the details of execution to
            // the database's query planner.
            mir::Operation::ExecStatement(Box::new(mir::ExecStatement {
                inputs,
                stmt,
                ty,
                conditional_update_with_no_returning: false,
            }))
        };

        // Insert the node and attach this statement's dependencies to it. No
        // earlier node on this path may have taken them already.
        debug_assert!(!self.did_take_deps);
        self.insert_mir_with_deps(node)
    }

    /// Lowers a conditional update into a single query built from two CTEs.
    ///
    /// - CTE 0 reads the target table under the update filter. It returns the
    ///   total row count and the count of rows that satisfy the condition.
    /// - CTE 1 performs the update. Its filter additionally requires the two
    ///   counts from CTE 0 to be equal.
    /// - The outer select left-joins CTE 0 with CTE 1. It returns the two
    ///   counts followed by CTE 1's returning columns.
    ///
    /// # Panics
    ///
    /// Panics if the update has no condition or filter, does not target a
    /// table, or has a returning expression that is not a record.
    fn plan_conditional_sql_query_as_cte(
        &self,
        inputs: IndexSet<mir::NodeId>,
        stmt: stmt::Update,
        ty: stmt::Type,
    ) -> mir::ExecStatement {
        let Some(condition) = stmt.condition.expr else {
            panic!("conditional update without condition");
        };

        let Some(filter) = stmt.filter.expr else {
            panic!("conditional update without filter");
        };

        let stmt::UpdateTarget::Table(target) = stmt.target.clone() else {
            panic!("conditional update without table");
        };

        let mut ctes = vec![];

        // Select from update table without the update condition.
        ctes.push(stmt::Cte {
            query: stmt::Query::builder(target)
                .filter(filter.clone())
                .returning(vec![
                    stmt::Expr::count_star(),
                    stmt::FuncCount {
                        arg: None,
                        filter: Some(Box::new(condition)),
                    }
                    .into(),
                ])
                .build(),
        });

        // Number of columns the update's own returning clause produces. These
        // follow the two count columns in the final record.
        let returning_len = match &stmt.returning {
            Some(stmt::Returning::Expr(expr)) => {
                let stmt::Expr::Record(expr_record) = expr else {
                    panic!("returning must be a record");
                };

                expr_record.fields.len()
            }
            Some(_) => todo!(),
            None => 0,
        };

        // The update statement. The update condition is expressed using the select above
        ctes.push(stmt::Cte {
            query: stmt::Query::new(stmt::Update {
                target: stmt.target,
                assignments: stmt.assignments,
                filter: stmt::Filter::new(stmt::Expr::and(
                    filter,
                    // SELECT found.count(*) = found.count(CONDITION) FROM found
                    stmt::Expr::stmt(stmt::Select {
                        source: stmt::TableRef::Cte {
                            nesting: 2,
                            index: 0,
                        }
                        .into(),
                        filter: true.into(),
                        returning: stmt::Returning::Expr(stmt::Expr::record_from_vec(vec![
                            stmt::Expr::eq(
                                stmt::ExprColumn {
                                    nesting: 0,
                                    table: 0,
                                    column: 0,
                                },
                                stmt::ExprColumn {
                                    nesting: 0,
                                    table: 0,
                                    column: 1,
                                },
                            ),
                        ])),
                    }),
                )),
                condition: Condition::default(),
                // Without a returning clause, a placeholder literal column is
                // returned instead (`returning_len` stays 0, so the outer
                // select never reads it).
                returning: Some(
                    stmt.returning
                        // TODO: hax
                        .unwrap_or_else(|| {
                            stmt::Returning::Expr(stmt::Expr::record_from_vec(vec![
                                stmt::Expr::from("hello"),
                            ]))
                        }),
                ),
            }),
        });

        // Columns 0 and 1 of CTE 0: the total count and the condition count.
        let mut columns = vec![
            stmt::Expr::column(stmt::ExprColumn {
                nesting: 0,
                table: 0,
                column: 0,
            }),
            stmt::Expr::column(stmt::ExprColumn {
                nesting: 0,
                table: 0,
                column: 1,
            }),
        ];

        // Followed by each returning column of the update (CTE 1).
        for i in 0..returning_len {
            columns.push(stmt::Expr::column(stmt::ExprColumn {
                nesting: 0,
                table: 1,
                column: i,
            }));
        }

        let stmt = stmt::Query::builder(stmt::Select {
            source: stmt::Source::table_with_joins(
                vec![
                    stmt::TableRef::Cte {
                        nesting: 0,
                        index: 0,
                    },
                    stmt::TableRef::Cte {
                        nesting: 0,
                        index: 1,
                    },
                ],
                stmt::TableWithJoins {
                    relation: stmt::TableFactor::Table(stmt::SourceTableId(0)),
                    joins: vec![stmt::Join {
                        table: stmt::SourceTableId(1),
                        constraint: stmt::JoinOp::Left(stmt::Expr::from(true)),
                    }],
                },
            ),
            filter: stmt::Filter::new(true),
            returning: stmt::Returning::Expr(stmt::Expr::record_from_vec(columns)),
        })
        .with(ctes)
        .build()
        .into();

        // NOTE(review): this flag is always `true` here. The only caller
        // (`plan_sql_execution`) asserts the update has no returning clause.
        mir::ExecStatement {
            inputs,
            stmt,
            ty,
            conditional_update_with_no_returning: true,
        }
    }

    /// Plans a conditional update as a read-modify-write operation. This path
    /// is used when the engine cannot run an update inside a CTE.
    ///
    /// - The `read` query counts the rows matched by the filter, alongside
    ///   the rows that also satisfy the update condition. It requests an
    ///   update lock when the engine supports `select_for_update`.
    /// - The `write` is the plain update with its condition removed.
    ///
    /// # Panics
    ///
    /// Panics if the update has a returning clause, has no condition or
    /// filter, or does not target a table.
    fn plan_conditional_sql_query_as_rmw(
        &mut self,
        inputs: IndexSet<mir::NodeId>,
        stmt: stmt::Update,
        ty: stmt::Type,
    ) -> mir::ReadModifyWrite {
        let stmt::Update {
            target,
            assignments,
            filter,
            condition,
            returning,
        } = stmt;

        // For now, no returning supported
        assert!(returning.is_none(), "TODO: support returning");

        let condition = match condition.expr {
            Some(condition) => condition,
            None => panic!("conditional update without condition"),
        };

        let filter = match filter.expr {
            Some(filter) => filter,
            None => panic!("conditional update without filter"),
        };

        let table = match target.clone() {
            stmt::UpdateTarget::Table(table) => table,
            _ => panic!("conditional update without table"),
        };

        // Neither SQLite nor MySQL support CTE with update, so the conditional
        // update becomes a transaction with checks between a read and a write.
        let locks = if self.planner.engine.capability().select_for_update {
            vec![stmt::Lock::Update]
        } else {
            vec![]
        };

        let read = stmt::Query::builder(table)
            .filter(filter.clone())
            .returning(vec![
                stmt::Expr::count_star(),
                stmt::FuncCount {
                    arg: None,
                    filter: Some(Box::new(condition)),
                }
                .into(),
            ])
            .locks(locks)
            .build();

        let write = stmt::Update {
            target,
            assignments,
            filter: stmt::Filter::new(filter),
            condition: stmt::Condition::default(),
            returning: None,
        };

        mir::ReadModifyWrite {
            inputs,
            read,
            write: write.into(),
            ty,
        }
    }

    // ===== NoSQL execution =====

    /// Plans the statement for a database without SQL capability. It uses
    /// the index chosen by `plan_index_path`.
    ///
    /// - Primary-key plans use key operations or `QueryPk`.
    /// - Other indices go through `FindPkByIndex`.
    ///
    /// Any filter the chosen operation does not apply is run afterwards as an
    /// in-memory `Filter` node.
    fn plan_nosql_execution(
        &mut self,
        linked: LinkedStatement,
        selection: &mut Selection,
        ref_source: Option<stmt::ExprArg>,
    ) -> mir::NodeId {
        // Without SQL capability, we have to plan the execution of the
        // statement based on available indices.
        let mut index_plan = self.planner.engine.plan_index_path(&linked.stmt);
        let pk_keys = self.try_build_pk_keys(&linked, &index_plan, ref_source);

        // This may add columns to `selection`, so it must run before the
        // record type is inferred below.
        let post_filter =
            self.prepare_post_filter(&linked, &mut index_plan, pk_keys.is_some(), selection);

        // Type of the final record.
        let ty = if selection.columns.is_empty() {
            stmt::Type::Unit
        } else {
            self.planner
                .engine
                .infer_record_list_ty(&linked.stmt, &selection.columns)
        };

        let node_id = if index_plan.index.primary_key {
            self.plan_primary_key_execution(
                linked,
                &mut index_plan,
                pk_keys,
                ref_source,
                selection,
                &ty,
            )
        } else {
            self.plan_secondary_index_execution(linked, &mut index_plan, selection, &ty)
        };

        self.apply_post_filter(node_id, post_filter, ty)
    }

    /// Plans a NoSQL statement that is served by the table's primary key.
    ///
    /// - When the filter was reduced to explicit keys (`pk_keys`), a key-based
    ///   operation is built over those keys.
    /// - Otherwise the table is queried with `QueryPk`. The index filter acts
    ///   as the key filter and the result filter as a row filter.
    fn plan_primary_key_execution(
        &mut self,
        linked: LinkedStatement,
        index_plan: &mut index::IndexPlan,
        pk_keys: Option<eval::Func>,
        ref_source: Option<stmt::ExprArg>,
        selection: &Selection,
        ty: &stmt::Type,
    ) -> mir::NodeId {
        match pk_keys {
            Some(keys) => {
                let key_ty = self.index_key_ty(index_plan);
                let key_input = self.build_get_by_key_input(keys, &linked, ref_source, key_ty);

                self.build_key_operation(&linked.stmt, index_plan, key_input, selection, ty)
            }
            None => {
                // Only zero or one input node is supported here so far.
                let input = match linked.inputs.len() {
                    0 => None,
                    1 => Some(linked.inputs[0]),
                    _ => todo!(),
                };

                self.insert_mir_with_deps(mir::QueryPk {
                    input,
                    table: index_plan.table_id(),
                    columns: selection.columns.clone(),
                    pk_filter: index_plan.index_filter.take(),
                    row_filter: index_plan.result_filter.take(),
                    ty: ty.clone(),
                })
            }
        }
    }

    /// Plans a NoSQL statement that is served through a secondary index.
    ///
    /// A `FindPkByIndex` node resolves the index filter to primary keys. Those
    /// keys then feed the key-based operation built by `build_key_operation`.
    fn plan_secondary_index_execution(
        &mut self,
        linked: LinkedStatement,
        index_plan: &mut index::IndexPlan,
        selection: &Selection,
        ty: &stmt::Type,
    ) -> mir::NodeId {
        assert!(index_plan.post_filter.is_none(), "TODO");

        let LinkedStatement { stmt, inputs } = linked;
        assert!(inputs.len() <= 1, "TODO: inputs={:#?}", inputs);

        let key_ty = self.index_key_ty(index_plan);
        let find_pks = mir::FindPkByIndex {
            inputs,
            table: index_plan.index.on,
            index: index_plan.index.id,
            filter: index_plan.index_filter.take(),
            ty: key_ty,
        };
        let pk_node_id = self.insert_mir_with_deps(find_pks);

        self.build_key_operation(&stmt, index_plan, pk_node_id, selection, ty)
    }

    /// Tries to reduce the index filter to an explicit set of primary keys.
    ///
    /// Returns `Some(keys)` when the plan uses the primary key and the filter
    /// can be converted. `keys` is a function computing the keys from the
    /// statement's inputs. Returns `None` otherwise.
    ///
    /// With a batch-load `ref_source`, the function's argument is the item
    /// type of the single input node's list.
    fn try_build_pk_keys(
        &mut self,
        linked: &LinkedStatement,
        index_plan: &index::IndexPlan,
        ref_source: Option<stmt::ExprArg>,
    ) -> Option<eval::Func> {
        // Only primary-key access paths can be turned into key lookups.
        if !index_plan.index.primary_key {
            return None;
        }

        let project_arg_tys = match ref_source {
            Some(_) => {
                assert_eq!(linked.inputs.len(), 1, "TODO");
                let batch_ty = self.planner.mir[linked.inputs[0]].ty();
                vec![batch_ty.unwrap_list_ref().clone()]
            }
            None => linked
                .inputs
                .iter()
                .map(|node_id| self.planner.mir[node_id].ty().clone())
                .collect(),
        };

        // Convert the index filter into keys, if possible.
        let cx = self.planner.engine.expr_cx_for(&linked.stmt);
        self.planner.engine.try_build_key_filter(
            cx,
            index_plan.index,
            &index_plan.index_filter,
            project_arg_tys,
        )
    }

    /// Computes the filter to apply in memory after rows are fetched. It
    /// starts from the index plan's `post_filter`.
    ///
    /// For queries served by key lookups or by a secondary index, the plan's
    /// `result_filter` is moved into the post filter. Column references in the
    /// post filter are added to `selection.columns` and rewritten to project
    /// from the fetched row record.
    fn prepare_post_filter(
        &mut self,
        linked: &LinkedStatement,
        index_plan: &mut index::IndexPlan,
        has_pk_keys: bool,
        selection: &mut Selection,
    ) -> Option<stmt::Expr> {
        let mut post_filter = index_plan.post_filter.clone();

        // If fetching rows using GetByKey, some databases do not support
        // applying additional filters to the rows before returning results.
        // In this case, the result_filter needs to be applied in-memory.
        if linked.stmt.is_query() && (has_pk_keys || !index_plan.index.primary_key) {
            if let Some(result_filter) = index_plan.result_filter.take() {
                post_filter = Some(match post_filter {
                    Some(post_filter) => stmt::Expr::and(result_filter, post_filter),
                    None => result_filter,
                });
            }
        }

        debug_assert!(
            post_filter.is_none() || linked.stmt.is_query(),
            "stmt={:#?}; post_filter={post_filter:#?}",
            linked.stmt
        );

        // Make sure we are including columns needed to apply the post filter
        if let Some(post_filter) = &mut post_filter {
            visit_mut::for_each_expr_mut(post_filter, |expr| match expr {
                stmt::Expr::Reference(expr_reference) => {
                    let (index, _) = selection.columns.insert_full(*expr_reference);
                    *expr = stmt::Expr::arg_project(0, [index]);
                }
                stmt::Expr::Arg(_) => todo!("expr={expr:#?}"),
                _ => {}
            });
        }

        post_filter
    }

    /// Wraps `node_id` in an in-memory `Filter` node when a post filter is
    /// present. Otherwise `node_id` is returned unchanged.
    ///
    /// `ty` is the list type of the rows produced by `node_id`. The filter is
    /// compiled against the list's item type.
    fn apply_post_filter(
        &mut self,
        node_id: mir::NodeId,
        post_filter: Option<stmt::Expr>,
        ty: stmt::Type,
    ) -> mir::NodeId {
        let Some(post_filter) = post_filter else {
            return node_id;
        };

        let row_ty = ty.unwrap_list_ref().clone();
        let filter = eval::Func::from_stmt(post_filter, vec![row_ty]);

        self.planner.mir.insert(mir::Filter {
            input: node_id,
            filter,
            ty,
        })
    }

    /// Produces the MIR node that supplies keys to a key-based operation.
    ///
    /// - Constant keys are evaluated now and inserted as a `Const` node.
    /// - Identity keys reuse the statement's single input node as-is.
    /// - Any other keys are projected out of the first input node. A
    ///   batch-load `ref_source` is expected in this case.
    fn build_get_by_key_input(
        &mut self,
        keys: eval::Func,
        linked: &LinkedStatement,
        ref_source: Option<stmt::ExprArg>,
        index_key_ty: stmt::Type,
    ) -> mir::NodeId {
        if keys.is_const() {
            let key_values = keys.eval_const();
            return self.insert_const(key_values, index_key_ty);
        }

        if keys.is_identity() {
            debug_assert_eq!(1, linked.inputs.len(), "TODO");
            return linked.inputs[0];
        }

        debug_assert!(ref_source.is_some(), "TODO");
        let projected_ty = stmt::Type::list(keys.ret.clone());

        // Gotta project
        self.planner.mir.insert(mir::Project {
            input: linked.inputs[0],
            projection: keys,
            ty: projected_ty,
        })
    }

    /// Builds the key-based MIR operation matching the statement kind. It
    /// reads keys from `get_by_key_input`.
    ///
    /// - Queries become `GetByKey`.
    /// - Deletes become `DeleteByKey`; their type must be unit.
    /// - Updates become `UpdateByKey`, carrying their assignments and
    ///   condition.
    ///
    /// Any remaining `result_filter` is taken off `index_plan` and handed to
    /// the delete or update operation.
    fn build_key_operation(
        &mut self,
        stmt: &stmt::Statement,
        index_plan: &mut index::IndexPlan,
        get_by_key_input: mir::NodeId,
        selection: &Selection,
        ty: &stmt::Type,
    ) -> mir::NodeId {
        match stmt {
            stmt::Statement::Query(_) => {
                debug_assert!(ty.is_list());
                self.insert_mir_with_deps(mir::GetByKey {
                    input: get_by_key_input,
                    table: index_plan.table_id(),
                    columns: selection.columns.clone(),
                    ty: ty.clone(),
                })
            }
            stmt::Statement::Delete(_) => {
                debug_assert!(
                    ty.is_unit(),
                    "stmt={stmt:#?}; returning={:#?}; ty={ty:#?}",
                    selection.returning
                );
                self.insert_mir_with_deps(mir::DeleteByKey {
                    input: get_by_key_input,
                    table: index_plan.table_id(),
                    filter: index_plan.result_filter.take(),
                    ty: stmt::Type::Unit,
                })
            }
            stmt::Statement::Update(update_stmt) => self.insert_mir_with_deps(mir::UpdateByKey {
                input: get_by_key_input,
                table: index_plan.table_id(),
                assignments: update_stmt.assignments.clone(),
                filter: index_plan.result_filter.take(),
                condition: update_stmt.condition.expr.clone(),
                ty: ty.clone(),
            }),
            _ => todo!("stmt={stmt:#?}"),
        }
    }

    // ===== Finalization helpers =====

    /// For each back-ref, inserts a `Project` node over the exec statement's
    /// output.
    ///
    /// The projection selects exactly the back-ref's expressions, in order,
    /// from the selected record. The resulting node id is stored in the
    /// back-ref's `node_id` cell.
    ///
    /// # Panics
    ///
    /// Panics if a back-ref expression is missing from `selection.columns`, or
    /// if the exec node's type is not a list.
    fn process_back_ref_projections(
        &mut self,
        exec_stmt_node_id: mir::NodeId,
        selection: &Selection,
    ) {
        for back_ref in self.stmt_info.back_refs.values() {
            let projection = stmt::Expr::record(back_ref.exprs.iter().map(|expr_reference| {
                let index = selection.columns.get_index_of(expr_reference).unwrap();
                stmt::Expr::arg_project(0, [index])
            }));

            // The projection runs per row, so its argument is the list's item type.
            let arg_ty = self.planner.mir[exec_stmt_node_id]
                .ty()
                .unwrap_list_ref()
                .clone();
            let projection = eval::Func::from_stmt(projection, vec![arg_ty]);
            let ty = stmt::Type::list(projection.ret.clone());

            let project_node_id = self.planner.mir.insert(mir::Project {
                input: exec_stmt_node_id,
                projection,
                ty,
            });
            back_ref.node_id.set(Some(project_node_id));
        }
    }

    /// Plans every sub-statement referenced through a `Sub` arg.
    ///
    /// Statements that were already planned are skipped by `plan_statement`.
    fn plan_child_statements(&mut self) {
        let child_ids = self.stmt_info.args.iter().filter_map(|arg| match arg {
            hir::Arg::Sub { stmt_id, .. } => Some(*stmt_id),
            _ => None,
        });

        for child_id in child_ids {
            self.planner.plan_statement(child_id);
        }
    }

    /// Chooses the node that produces the statement's final output. The checks
    /// run in this order:
    ///
    /// 1. If `plan_nested_merge` plans a nested merge, its node is returned.
    /// 2. A constant returning value becomes a `Const` node that depends on the
    ///    exec node.
    /// 3. A returning expression becomes a `Project` over the exec node's rows.
    /// 4. Without a returning clause, the exec node itself is the output. Any
    ///    dependencies not yet attached are added to the exec node.
    fn plan_output_node(
        &mut self,
        exec_stmt_node_id: mir::NodeId,
        returning: Option<stmt::Returning>,
        single: bool,
        ref_source: Option<stmt::ExprArg>,
    ) -> mir::NodeId {
        // First check for nested merge
        if let Some(node_id) = self.planner.plan_nested_merge(self.stmt_id) {
            return node_id;
        }

        // Then handle returning clause
        if let Some(returning) = returning {
            debug_assert!(
                !single || ref_source.is_some(),
                "TODO: single queries not supported here"
            );

            match returning {
                stmt::Returning::Value(returning) => {
                    let ty = returning.infer_ty();

                    let stmt::Value::List(rows) = returning else {
                        todo!(
                            "unexpected returning type; returning={returning:#?}; stmt={:#?}",
                            self.stmt_info.stmt
                        )
                    };

                    // The constant must still wait for the exec node to run.
                    self.planner
                        .mir
                        .insert_with_deps(mir::Const { value: rows, ty }, [exec_stmt_node_id])
                }
                stmt::Returning::Expr(returning) => {
                    // A unit-typed exec node produces no row argument for the projection.
                    let arg_ty = match self.planner.mir[exec_stmt_node_id].ty() {
                        stmt::Type::List(ty) => vec![(**ty).clone()],
                        stmt::Type::Unit => vec![],
                        _ => todo!(),
                    };

                    let projection = eval::Func::from_stmt(returning, arg_ty);
                    let ty = stmt::Type::list(projection.ret.clone());

                    let node = mir::Project {
                        input: exec_stmt_node_id,
                        projection,
                        ty,
                    };

                    // Plan the final projection to handle the returning clause.
                    self.insert_mir_with_deps(node)
                }
                returning => panic!("unexpected `stmt::Returning` kind; returning={returning:#?}"),
            }
        } else {
            if let Some(dependencies) = self.take_dependencies() {
                self.planner.mir[exec_stmt_node_id]
                    .deps
                    .extend(dependencies);
            }

            exec_stmt_node_id
        }
    }

    // ===== MIR/utility helpers =====

    /// Inserts a constant MIR node holding `value`.
    ///
    /// `ty` must be a list type and `value` must conform to it. Both
    /// requirements are checked with debug assertions. The node stores the
    /// unwrapped list of values.
    #[track_caller]
    fn insert_const(&mut self, value: impl Into<stmt::Value>, ty: stmt::Type) -> mir::NodeId {
        let value: stmt::Value = value.into();

        // Constants always describe a list of rows.
        debug_assert!(
            ty.is_list(),
            "const types must be of type `stmt::Type::List`"
        );
        debug_assert!(
            value.is_a(&ty),
            "const type mismatch; expected={ty:#?}; actual={value:#?}",
        );

        let node = mir::Const {
            value: value.unwrap_list(),
            ty,
        };
        self.planner.mir.insert(node)
    }

    /// Inserts `node` into the MIR graph.
    ///
    /// The statement's dependencies are attached if no earlier node has taken
    /// them yet.
    fn insert_mir_with_deps(&mut self, node: impl Into<mir::Node>) -> mir::NodeId {
        match self.take_dependencies() {
            Some(dependencies) => self.planner.mir.insert_with_deps(node, dependencies),
            None => self.planner.mir.insert(node),
        }
    }

    /// Returns the statement's dependent operations on the first call, and
    /// `None` on every later call.
    ///
    /// As a result, the dependencies are attached to at most one MIR node per
    /// statement.
    fn take_dependencies(&mut self) -> Option<impl Iterator<Item = mir::NodeId> + 'a> {
        if self.did_take_deps {
            return None;
        }

        self.did_take_deps = true;
        Some(self.stmt_info.dependent_operations(self.planner.hir))
    }

    /// Returns the list type of keys for the plan's index.
    ///
    /// Each item is a plain value for single-column keys, or a record for
    /// composite keys.
    fn index_key_ty(&self, index_plan: &IndexPlan) -> stmt::Type {
        let key_ty = self.planner.engine.index_key_record_ty(index_plan.index);
        stmt::Type::list(key_ty)
    }
}
